package com.alibaba.alink.operator.common.clustering.lda;

import org.apache.flink.api.java.tuple.Tuple2;

import com.alibaba.alink.common.linalg.DenseMatrix;
import com.alibaba.alink.common.linalg.SparseVector;
import com.alibaba.alink.testutil.AlinkTestBase;
import org.junit.Assert;
import org.junit.Test;

public class LdaUtilTest extends AlinkTestBase {

	@Test
	public void dirichletExpectationTest() {
		double[] sparkCmp = new double[] {
			0.8936825549031158,
			0.9650683744577933,
			1.1760851442955271,
			0.889011463028263,
			1.0355502890838704,
			1.1720254142865503,
			0.8496512959061578,
			1.1564109073902848,
			0.8528198328651976,
			1.072261907065107,
			1.0112487630821958,
			1.0288027427394206,
			1.1256918577237478,
			1.0641131417250107,
			0.9830788207753957,
			0.9519235842178695,
			1.0531103642783968,
			1.0846663792488604,
			0.9317316401779444,
			0.9816247167440154,
			0.953061129524052,
			0.8836097897537777,
			0.8539728772760822,
			1.109432137460693,
			0.9801693423689286,
			0.9385725168762017,
			1.009886079821316,
			0.9741390218380398,
			0.8734624459614093,
			0.8548583255850564,
			0.8934120594879987,
			1.0200469492393616,
			0.9461610896051537,
			1.1912819895664948,
			0.9650275833536232,
			0.9312815665885328,
			0.984681817963758,
			1.1412711858668625,
			1.1159082714127344,
			1.0219124026668207,
			1.1052645047308647,
			1.1380919062139254,
			0.9684793634316371,
			1.023922805813918,
			1.0777999541431174,
			0.8730213177341947,
			1.0353598060502658,
			1.047104264664753,
			1.1284793487722498,
			0.8898021261569816,
			1.1634869627283706,
			0.817874601150865,
			1.0424867867765728,
			1.167773175905418,
			0.915224402643435};
		DenseMatrix lambda = new DenseMatrix(11, 5, sparkCmp, false);

		DenseMatrix expElogbeta = LdaUtil.dirichletExpectation(lambda);

		Assert.assertTrue(Math.abs(expElogbeta.get(0, 0) + 2.2832832787919575) < 10e-4);
	}

	@Test
	public void expDirichletExpectationTest() {
		double[] sparkCmp = new double[] {
			0.8936825549031158,
			0.9650683744577933,
			1.1760851442955271,
			0.889011463028263,
			1.0355502890838704,
			1.1720254142865503,
			0.8496512959061578,
			1.1564109073902848,
			0.8528198328651976,
			1.072261907065107,
			1.0112487630821958,
			1.0288027427394206,
			1.1256918577237478,
			1.0641131417250107,
			0.9830788207753957,
			0.9519235842178695,
			1.0531103642783968,
			1.0846663792488604,
			0.9317316401779444,
			0.9816247167440154,
			0.953061129524052,
			0.8836097897537777,
			0.8539728772760822,
			1.109432137460693,
			0.9801693423689286,
			0.9385725168762017,
			1.009886079821316,
			0.9741390218380398,
			0.8734624459614093,
			0.8548583255850564,
			0.8934120594879987,
			1.0200469492393616,
			0.9461610896051537,
			1.1912819895664948,
			0.9650275833536232,
			0.9312815665885328,
			0.984681817963758,
			1.1412711858668625,
			1.1159082714127344,
			1.0219124026668207,
			1.1052645047308647,
			1.1380919062139254,
			0.9684793634316371,
			1.023922805813918,
			1.0777999541431174,
			0.8730213177341947,
			1.0353598060502658,
			1.047104264664753,
			1.1284793487722498,
			0.8898021261569816,
			1.1634869627283706,
			0.817874601150865,
			1.0424867867765728,
			1.167773175905418,
			0.915224402643435};
		DenseMatrix lambda = new DenseMatrix(11, 5, sparkCmp, false);

		DenseMatrix expElogbeta = LdaUtil.expDirichletExpectation(lambda);

		Assert.assertTrue(Math.abs(expElogbeta.get(0, 0) - 0.101948) < 10e-4);
	}

	@Test
	public void getTopicDistributionMethodTest() {

		int row = 11;
		int col = 5;

		double[] temp = new double[] {0.8936825549031158,
			0.9650683744577933,
			1.1760851442955271,
			0.889011463028263,
			1.0355502890838704,
			1.1720254142865503,
			0.8496512959061578,
			1.1564109073902848,
			0.8528198328651976,
			1.072261907065107,
			1.0112487630821958,
			1.0288027427394206,
			1.1256918577237478,
			1.0641131417250107,
			0.9830788207753957,
			0.9519235842178695,
			1.0531103642783968,
			1.0846663792488604,
			0.9317316401779444,
			0.9816247167440154,
			0.953061129524052,
			0.8836097897537777,
			0.8539728772760822,
			1.109432137460693,
			0.9801693423689286,
			0.9385725168762017,
			1.009886079821316,
			0.9741390218380398,
			0.8734624459614093,
			0.8548583255850564,
			0.8934120594879987,
			1.0200469492393616,
			0.9461610896051537,
			1.1912819895664948,
			0.9650275833536232,
			0.9312815665885328,
			0.984681817963758,
			1.1412711858668625,
			1.1159082714127344,
			1.0219124026668207,
			1.1052645047308647,
			1.1380919062139254,
			0.9684793634316371,
			1.023922805813918,
			1.0777999541431174,
			0.8730213177341947,
			1.0353598060502658,
			1.047104264664753,
			1.1284793487722498,
			0.8898021261569816,
			1.1634869627283706,
			0.817874601150865,
			1.0424867867765728,
			1.167773175905418,
			0.915224402643435};

		DenseMatrix lambda = new DenseMatrix(row, col, temp, false);

		DenseMatrix expElogbeta = LdaUtil.expDirichletExpectation(lambda);

		DenseMatrix alpha = new DenseMatrix(5, 1, new double[] {0.2, 0.3, 0.4, 0.5, 0.6});

		DenseMatrix gammad = new DenseMatrix(5, 1, new double[] {0.7, 0.8, 0.9, 1.0, 1.1});

		///////////////////////////////////////////////////////////////
		SparseVector sv = new SparseVector(11, new int[] {0, 1, 2, 4, 5, 6, 7, 10},
			new double[] {1.0, 2.0, 6.0, 2.0, 3.0, 1.0, 1.0, 3.0});

		Tuple2 <DenseMatrix, DenseMatrix> re = LdaUtil.getTopicDistributionMethod(sv, expElogbeta, alpha, gammad, 5);

		System.out.println(re);

		Assert.assertTrue(Math.abs(re.f0.get(3, 0) - 1.6055989357674745) < 10e-4);
		Assert.assertTrue(Math.abs(re.f1.get(2, 2) - 0.39534340684397445) < 10e-4);
	}

	@Test
	public void getTopicDistributionMethodTest2() {

		int row = 11;
		int col = 5;

		double[] temp = new double[] {0.8936825549031158,
			0.9650683744577933,
			1.1760851442955271,
			0.889011463028263,
			1.0355502890838704,
			1.1720254142865503,
			0.8496512959061578,
			1.1564109073902848,
			0.8528198328651976,
			1.072261907065107,
			1.0112487630821958,
			1.0288027427394206,
			1.1256918577237478,
			1.0641131417250107,
			0.9830788207753957,
			0.9519235842178695,
			1.0531103642783968,
			1.0846663792488604,
			0.9317316401779444,
			0.9816247167440154,
			0.953061129524052,
			0.8836097897537777,
			0.8539728772760822,
			1.109432137460693,
			0.9801693423689286,
			0.9385725168762017,
			1.009886079821316,
			0.9741390218380398,
			0.8734624459614093,
			0.8548583255850564,
			0.8934120594879987,
			1.0200469492393616,
			0.9461610896051537,
			1.1912819895664948,
			0.9650275833536232,
			0.9312815665885328,
			0.984681817963758,
			1.1412711858668625,
			1.1159082714127344,
			1.0219124026668207,
			1.1052645047308647,
			1.1380919062139254,
			0.9684793634316371,
			1.023922805813918,
			1.0777999541431174,
			0.8730213177341947,
			1.0353598060502658,
			1.047104264664753,
			1.1284793487722498,
			0.8898021261569816,
			1.1634869627283706,
			0.817874601150865,
			1.0424867867765728,
			1.167773175905418,
			0.915224402643435};

		DenseMatrix lambda = new DenseMatrix(row, col, temp, false).transpose();

		System.out.println(lambda);

		DenseMatrix expElogbeta = LdaUtil.expDirichletExpectation(lambda).transpose();

		DenseMatrix alpha = new DenseMatrix(5, 1, new double[] {0.2, 0.2, 0.2, 0.2, 0.2});

		DenseMatrix gammad = new DenseMatrix(5, 1, new double[] {0.7, 0.8, 0.9, 1.0, 1.1});

		///////////////////////////////////////////////////////////////
		//        SparseVector sv = new SparseVector(11, new int[]{0, 1, 3, 4, 7, 10}, new double[]{1.0, 3.0, 1.0,
		// 3.0, 2.0, 1.0});
		//        SparseVector sv = new SparseVector(11, new int[]{0, 1, 2, 4, 5, 6, 7, 10}, new double[]{1.0, 2.0,
		// 6.0, 2.0, 3.0, 1.0, 1.0, 3.0});

		//        SparseVector sv = new SparseVector(11, new int[]{0, 1, 3, 4, 7, 10}, new double[]{1.0, 3.0, 1.0,
		// 3.0, 2.0, 1.0});
		SparseVector sv = new SparseVector(11, new int[] {0, 1, 3, 6, 8, 9, 10},
			new double[] {2.0, 1.0, 3.0, 5.0, 2.0, 2.0, 9.0});

		Tuple2 <DenseMatrix, DenseMatrix> re = LdaUtil.getTopicDistributionMethod(sv, expElogbeta, alpha, gammad, 5);

		System.out.println(re);
	}

	@Test
	public void digammaTest() {
		double r = LdaUtil.digamma(1e-5);
		System.out.println(r);
		double r1 = LdaUtil.digamma(1e-100);
		System.out.println(r1);
	}
}